import torch

# Sample tensors drawn from a standard normal distribution. Their shapes are
# chosen so that b (1, 3) and c (1,) are broadcast-compatible with a (2, 3).
a = torch.randn(2, 3)
b = torch.randn(1, 3)
c = torch.randn(1)

# Implicit broadcasting: elementwise `+` stretches the smaller operand to
# match `a`, so each sum has a's shape.
print(*((a + rhs).shape for rhs in (b, c)), sep="\n")

# Explicit broadcasting: `expand_as` returns a view of the smaller tensor
# shaped like `a` (size-1 dims are expanded, missing leading dims are added).
print(*(small.expand_as(a).shape for small in (b, c)), sep="\n")